package com.gitee.jmash.rbac.client.shiro.grpc;

import com.gitee.jmash.common.grpc.GrpcContext;
import com.gitee.jmash.core.grpc.cdi.GrpcServerInterceptor;
import com.gitee.jmash.core.utils.AnnotationUtils;
import com.gitee.jmash.rbac.client.shiro.aop.GrpcAnnotationsAuthorizingMethodInterceptor;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import jakarta.annotation.Priority;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.shiro.aop.MethodInvocation;
import org.apache.shiro.authz.annotation.RequiresAuthentication;
import org.apache.shiro.authz.annotation.RequiresGuest;
import org.apache.shiro.authz.annotation.RequiresPermissions;
import org.apache.shiro.authz.annotation.RequiresRoles;
import org.apache.shiro.authz.annotation.RequiresUser;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.util.ThreadContext;

/**
 * 权限拦截
 * 
 * @author cgd
 *
 */
@GrpcServerInterceptor
@Priority(2000)
public class ShiroPermitInterceptor implements ServerInterceptor {

  private static final Log logger = LogFactory.getLog(ShiroPermitInterceptor.class);

  public static final String MARK = ":";

  @SuppressWarnings("unchecked")
  private static final Class<? extends Annotation>[] AUTHZ_ANNOTATION_CLASSES =
      new Class[] {RequiresPermissions.class, RequiresRoles.class, RequiresUser.class,
          RequiresGuest.class, RequiresAuthentication.class};

  @SuppressWarnings({"unchecked", "rawtypes"})
  @Override
  public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
      ServerCallHandler<ReqT, RespT> next) {
    try {
      Subject subject = (Subject) GrpcContext.USER_SUBJECT.get();
      // 不能删除
      ThreadContext.unbindSubject();
      ThreadContext.bind(subject);
      // 租户如何控制?.
      Object impl = getImplObject(next);
      Method method = getMethod(impl, call);
      // Shiro 注解权限控制
      if (isMatchAnnotation(method, impl.getClass())) {
        // RequiresGuest 不控制
        if (RequiresGuest.class.equals(getAnnotation(method, impl.getClass()).annotationType())) {
          return next.startCall(call, headers);
        }
        try {
          GrpcAnnotationsAuthorizingMethodInterceptor inter =
              new GrpcAnnotationsAuthorizingMethodInterceptor();
          return (Listener) inter.invoke(new MethodInvocation() {
            @Override
            public Object proceed() throws Throwable {
              return next.startCall(call, headers);
            }

            @Override
            public Method getMethod() {
              return method;
            }

            @Override
            public Object[] getArguments() {
              return null;
            }

            @Override
            public Object getThis() {
              return impl;
            }
          });
        } catch (Throwable ex) {
          logger.error("403 Forbidden：" + call.getMethodDescriptor().getFullMethodName());
          // logger.error("", ex);
          Status status =
              Status.PERMISSION_DENIED.withDescription(" 403 Forbidden Client Connection. "
                  + call.getMethodDescriptor().getFullMethodName());
          call.close(status, headers);
          return new ServerCall.Listener<ReqT>() {
            // noop
          };
        }
      } else {
        return next.startCall(call, headers);
      }
    } catch (Throwable ex) {
      logger.error("Permit Exception ：" + call.getMethodDescriptor().getFullMethodName());
      logger.error("", ex);
      return next.startCall(call, headers);
    }
  }

  /** 是否包含注解. */
  public Boolean isMatchAnnotation(Method m, Class<?> clazz) {
    return isAuthzAnnotationPresent(m) || isAuthzAnnotationPresent(clazz);
  }

  /** 类是否包含注解. */
  private boolean isAuthzAnnotationPresent(Class<?> targetClazz) {
    for (Class<? extends Annotation> annClass : AUTHZ_ANNOTATION_CLASSES) {
      Annotation a = AnnotationUtils.findAnnotation(targetClazz, annClass);
      if (a != null) {
        return true;
      }
    }
    return false;
  }

  /** 方法是否包含注解. */
  private boolean isAuthzAnnotationPresent(Method method) {
    for (Class<? extends Annotation> annClass : AUTHZ_ANNOTATION_CLASSES) {
      Annotation a = AnnotationUtils.findAnnotation(method, annClass);
      if (a != null) {
        return true;
      }
    }
    return false;
  }

  /** 包含注解. */
  public Annotation getAnnotation(Method m, Class<?> clazz) {
    Annotation a = getAuthzAnnotationPresent(m);
    if (a != null) {
      return a;
    }
    return getAuthzAnnotationPresent(clazz);
  }

  /** 类包含注解. */
  private Annotation getAuthzAnnotationPresent(Class<?> targetClazz) {
    for (Class<? extends Annotation> annClass : AUTHZ_ANNOTATION_CLASSES) {
      Annotation a = AnnotationUtils.findAnnotation(targetClazz, annClass);
      if (a != null) {
        return a;
      }
    }
    return null;
  }

  /** 方法包含注解. */
  private Annotation getAuthzAnnotationPresent(Method method) {
    for (Class<? extends Annotation> annClass : AUTHZ_ANNOTATION_CLASSES) {
      Annotation a = AnnotationUtils.findAnnotation(method, annClass);
      if (a != null) {
        return a;
      }
    }
    return null;
  }

  /** 获取方法. */
  public Method getMethod(Object impl, ServerCall<?, ?> call) {
    try {
      String fullMethodName = call.getMethodDescriptor().getFullMethodName();
      String methodName = fullMethodName.split("/")[1];
      Method[] methods = impl.getClass().getMethods();
      for (Method m : methods) {
        if (StringUtils.equalsIgnoreCase(methodName, m.getName())) {
          return m;
        }
      }
    } catch (Exception ex) {
      logger.error("", ex);
    }
    return null;
  }

  /** 获取实现对象. */
  public Object getImplObject(ServerCallHandler<?, ?> next) {
    try {
      Object callHandler = getFieldValue(next, "callHandler");
      return getImplObject((ServerCallHandler<?, ?>) callHandler);
    } catch (Exception ex) {
      try {
        Object method = getFieldValue(next, "method");
        Object serviceImpl = getFieldValue(method, "serviceImpl");
        return serviceImpl;
      } catch (Exception ex1) {
        logger.error("", ex1);
      }
    }
    return null;
  }

  /** 获取对象私有属性值. */
  public Object getFieldValue(Object obj, String fieldName) throws NoSuchFieldException,
      SecurityException, IllegalArgumentException, IllegalAccessException {
    Field field = obj.getClass().getDeclaredField(fieldName);
    field.setAccessible(true);
    Object value = field.get(obj);
    return value;
  }

}
